{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-03-04 13:39:39.808055: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2023-03-04 13:39:45.632101: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
      "2023-03-04 13:39:47.031019: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n",
      "2023-03-04 13:39:47.031037: I tensorflow/compiler/xla/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n",
      "2023-03-04 13:39:58.154757: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
      "2023-03-04 13:39:58.154849: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
      "2023-03-04 13:39:58.154858: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import tensorflow_compression as tfc\n",
    "import tensorflow_probability as tfp\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from time import time\n",
    "from tqdm import tqdm\n",
    "\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Decoder(tf.keras.layers.Layer):\n",
    "    \"\"\"Encoder network for the VAE.\"\"\"\n",
    "    \n",
    "    def __init__(self, N):\n",
    "        \"\"\"Initializes the encoder.\"\"\"\n",
    "        \n",
    "        super(Decoder, self).__init__()\n",
    "        self.N      = N\n",
    "        self.conv2  = tf.keras.layers.Conv2DTranspose(self.N, 5, strides=2)\n",
    "        self.conv1  = tf.keras.layers.Conv2DTranspose(self.N, 5, strides=2)\n",
    "        self.conv3  = tf.keras.layers.Conv2DTranspose(self.N, 5, strides=2)\n",
    "        self.conv4  = tf.keras.layers.Conv2DTranspose(3, 5, strides=2)\n",
    "        self.gdn1   = tfc.layers.GDN(inverse=True)\n",
    "        self.gdn2   = tfc.layers.GDN(inverse=True)\n",
    "        self.gdn3   = tfc.layers.GDN(inverse=True)\n",
    "    \n",
    "    def call(self, inputs):\n",
    "        \"\"\"Forward pass of the decoder.\"\"\"\n",
    "        x = self.conv1(inputs)\n",
    "        x = self.gdn1(x)\n",
    "        x = self.conv2(x)\n",
    "        x = self.gdn2(x)\n",
    "        x = self.conv3(x)\n",
    "        x = self.gdn3(x)\n",
    "        z = self.conv4(x)\n",
    "        return z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def indexes(i):\n",
    "    return i\n",
    "\n",
    "def get_indexed_emodel(num_scales):\n",
    "    return tfc.LocationScaleIndexedEntropyModel(\n",
    "        prior_fn=tfc.NoisyNormal,\n",
    "        num_scales=num_scales,\n",
    "        scale_fn = indexes,\n",
    "        coding_rank=1,\n",
    "     )\n",
    "\n",
    "def get_batched_emodel(batch_shape=()):\n",
    "    return tfc.ContinuousBatchedEntropyModel(\n",
    "        prior=tfc.distributions.NoisyDeepFactorized(batch_shape=batch_shape),\n",
    "        coding_rank=1\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(\n",
      "[[ 0.6269213  -0.28662133]\n",
      " [ 1.0749947   0.8063948 ]], shape=(2, 2), dtype=float32)\n",
      "tf.Tensor([2.0862703 4.9217143], shape=(2,), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "# bemodel = get_batched_emodel((19, 19))\n",
    "iemodel = get_indexed_emodel(19)\n",
    "\n",
    "# y_tilde, rate_i = bemodel\n",
    "y     = tf.random.uniform((2,2))\n",
    "sigma = 0.5\n",
    "y_tilde, rate_i = iemodel(\n",
    "                        y, \n",
    "                        sigma, \n",
    "                        training=True,\n",
    "                        )\n",
    "\n",
    "print(y_tilde)\n",
    "print(rate_i)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tf",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "220aeacb52331f97f7780d84158167722a7cc0056b78b1c27b32a8f95243ca77"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
